{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Binary segmentation with the Lovász Hinge\n",
    "## TensorFlow version; see the PyTorch version for more details and use in training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:10.230570Z",
     "start_time": "2019-02-26T17:10:10.212590Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:10.246446Z",
     "start_time": "2019-02-26T17:10:10.232957Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import division, print_function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:10.258048Z",
     "start_time": "2019-02-26T17:10:10.247788Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:14.921553Z",
     "start_time": "2019-02-26T17:10:10.260680Z"
    }
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:14.951312Z",
     "start_time": "2019-02-26T17:10:14.924015Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import lovasz_losses_tf as L"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.184676Z",
     "start_time": "2019-02-26T17:10:14.952563Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.228515Z",
     "start_time": "2019-02-26T17:10:15.186524Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# add the demo_helpers folder to the module search path to import demo_utils\n",
    "import sys\n",
    "sys.path.append('../demo_helpers')\n",
    "from demo_utils import pil_grid, dummy_triangles\n",
    "from demo_utils_tf import define_scope"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.704993Z",
     "start_time": "2019-02-26T17:10:15.230535Z"
    }
   },
   "outputs": [],
   "source": [
    "# add pytorch folder to path for comparison\n",
    "sys.path.append('../pytorch')\n",
    "import lovasz_losses as L_pytorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.739399Z",
     "start_time": "2019-02-26T17:10:15.706475Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "H = W = 200\n",
    "IGNORE = 255\n",
    "classes = [0, 255, 1]\n",
    "C = len([c for c in classes if c != IGNORE])\n",
    "B = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.803022Z",
     "start_time": "2019-02-26T17:10:15.740518Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAADKCAIAAAAASlGDAAAKqklEQVR4nO3d4XXbOBOFYfk7W4hLSSkp5ZaypaQUd8LvB7OKYokUCGAwM8D7/Niz8crgHC9zeUHT8u0GGNi2zXsEzGnbtv95zwAA1xBbAJIhtgAkQ2wBSIbYApAMsQUgGWILQDLEFoBkiC0AyRBbAJIhtgAkQ2wBSIbYApAMsQUgGWILQDLEFoBkiC0AyRBbAJJxjq2vr1++AwDAZV9fvwiv+fBe8jAS4r3kPz9/3KhdAIp9eA/w2z229hRDdtu2fXxEObswkxBta3dPK2oXgHOxroePmUXtSo22BSOB2tbuMaqoXQBeing9/BZY1K6MaFswEq5tvUTtAvAo6PXwZVRRuxKhbcHItm1xT6yjkkV4pUBswUjoTeJRPLFnBBYX/Xp4ElLUrshoWzASum3tTrKJ2gWsKcf18DyhqF0B0bZgJEHb2p0HE7ULWEqm6+HbeKJ2xUHbgpHQD0C8VFKsCK8IiC0YSbNJvCuJJPaMwNxSXg8Lg4na5Yi2BSP52tauMI+oXcCUEl8Py1OJ2jUebQtGsratXXkYUbuAmaS/Hl6KJGrXMLQtGMndtipQu4AJzHA9rAgjapc12haM5Hvc9EhdjSK87BBbMDLPJrEugNgzAhlNdT2sjiFqV3e0LRiZp23tqtOH2gUkMuH1sCWDqF290LZgZLa2tWuJHmoXEN+018PGAKJ2NaJtwcicbWvXmDvULiCsya+H7elD7apD24KReR43PdGlNxFeVxFbMDLzJvGuS+KwZwTiWOV62Ct3qF2FaFswskTb2vWKG2oX4G6t62HH0KF2naNtwchCbas7ahfgZbnrYfe4oXa9RNuCkSUegHhmUZQIr2+ILRhZdJNoETHsGYFh1r0eGgUNtWtH24KRRdvWzihfqF2AtdWvh3Yps3jtom3ByNJta2cXLulql7wHAApxPbzdjCMmfu3S07+0o23BCG3rN9NkiVy7ZJNZWI3GHo7r4R/W+RKqdqngIy1oW0tR8QfbLfq46YkVkkvFH2xBbK1DBq88sW3bPz3Wmcfn54/Ie7pG8h4Ai1PxB89xPXxhvjv0avivdWhbi9DwNdkkHpomudTpNVcRW4vQ8KPwncRDEW5CtZP3AJibnI5CbDkYcPtMxadU4cuAOIitQ0kf5hJJhCHkdxRi60yuraKun0lXXw+MpIOPE1tvpPihRRFAGEuuhyC23ovcudRwAlV/IuCL2CoS8M25RO5gXjr9r8RWqTjJJQILruS9PrF1gftuUf3OmF7rAH2p4DXE1jVevz5DBA1ikPcAN2KrwuDkksGJ0n1BoAuVvYzYqjFmtyjyBcEoxuLEVqXuyfVYuGR5ftitDFTTlRcTW/UskkvECqKS9wB3xFaTvsn1r/3eU9YHAK7Txdfz7qaturwh6oDAAloo0srEVgctyUVgAVcRW27GB5YGHw94R1WfRWz1calw0bCQi4ItS2x1U5JcjoElrwMDr6jhc4mtnk6Si4aFpOQ9wDNiq7Pn5IoQWPIeAHiktk8ntvq7J1eEwAJaKOSaxJYJAgt4ST0WIbY6k/cAz+Q9AJKS9wBHiK1u5D0AEJw6rUNsdSDvAU7IewAkpcALEltN5D0AkIK6rsY7QNST9wBvyXsAJCXvAc7RtmrIewAgEfVekNi6Rt4DALnIYE1iq5S8B7hK3gMgKXkP8Bax9Z68BwCSks2yxNYZeQ9QTd4DICmFXOobYus1eQ8ApCbLxYmt7+Q9QDt5D4Ck5D1AIWLrD3kPAMxBxuvzuOntxu9/hoH23+c0mIKtc2L1tiXvAbqT9wC4e0yu7r8MOCYNOcq6sSXvAbCU4BEm7wEuWTG25D2AHXkPgBLBI6yaRh1ordiS9wBYRPmNrWkiTAOPtUpsyXuAAeQ9ABrdI2xwfinACpcsEVvyHgC4ZJoKZuTDewBb8h5gmJ9fv0Kd39u2fXxM
fnadMHr6weJ/sQKscMm2bdO2LXkPAHQXsIXJ46ATxpa8B/DyFaxwwVSXCFOPScabKrbkPYCXn/+dwSTXmrxamIYd6W+TxJa8BwDufH+spzzC1Hagxk9vkT625D2Au59//yWhcOHOroWp41rXJY4teQ8QFsmFZ98iTG6DdJDyHSBEZv3nZ7a3GUAEjdtY9ZmiXrLYUoAvWQrp3jUFWch7gFui2FKMr1ciJBeO/Jv8gYkE97bkPUBY7BCxptBtS2RWGwoXnlVXrTiXyaA/NSbvAeIrP4dcvqu47M8kxr9U1MXW8/nm9d3qoD+TKO8BgFn1yqyb6w9IxooteQ+QhW43ff4ovLDzGBesDY6wKLEl7wEy+ixOLuDWtWqdGBBh/ncf5D1AcHr3gsLkGly4uLcVUEVs9boN3/H027bN88SS47FDUu0nBkwuYiuaMVWrRON56HZLXi5HDUO9F2S3CAtGTzy07yJHXw81+HjeNPBYJck1rHCt2bYiXzyutq3xT2kVnpxD25aGHcmJvAcowXcV1xQ/s25XWtiI2NKAY4wl7wFeYquI7AqvqbY1XqarDyHvAa4KslVkkxhHiqpVfk4abhJltK4ZeQ/QS0nnYquII4Mzq+487B9b6r5ib/IewBq7RdxdqlojM6vlwtkzttRxrU7kPYCXt8lF4YKLLmddn7sP6rJKG3kPENDbzmWXXAve2wrYcENVrV4nW597W2pfIvDhUmO3iEJ2mWVxaWy6HqrTEF7rL+I8uYwKF23LXXnVMsosu1Orsm2p6xwWC+LuvHNxkwt9RXzjGjUfsn0FXMVuEUcCvsfDWxdqvK6vXvEpsHOSXN3POTaJvgp3iF0ya/x7IhW1LTW/ABHQudCR472FN9dDFXwEuRwlV9+zcLW2Fep6YF21fG+Gnr1NoEioeQ1ILmLLUUlsVWRWkG/dbNt2+HsSNXAODBbk/IMFi8z6/PwR6pyJ8iswEAEPQ+BR2JMh9G+lhp2jMzLUZgdX9apa0erVN7StdfGNxQWdZ1bkqHq00E1TvPQyudpPX27Jj/e2ah1lVpa02p3dksciXp6yEf4SJpL3yxV8M3iETSLYLS7hsWpljKpHtC3cbq/OY4Isl/Md4j2zktarb4gt/EZyzern1689rSYIrB2xhT+mOa1Xc1615vvfSmzhL99OcQpXdvIewAKxhe9IrlxOqpbGTTEUsYUX5ttWYCbEFl57TC4K1wnfL86CVetGbOEEyRXZ+XcGNW4QBzxuijM8iRrNY1Tp4DVHH58GbQtv3P+ekF++ZnrwqhFtCwjtKKp08Pqjj8+EtoX3KFwuKuqVTAYJh9hCEZJrmOofxFH/WYIitlCKGyvP+oZ4eVqp41ETIrZwwf6XisLVV5efc1afWXJY6P0n0cseW+d/zRZ5d9OWBK/OKRV8ZGK8uylqsFts1PdRBvVaKA8egECN/TFU8uuSLl8utS+RH7GFSmRWObuvlYzWjY3YAqx0Tyud/nEdxBZQ6eR+PFXUFLEFdGOaVjr941KILaCDwfVKIw8WD7EF1HPZDGr8IYPhuS2g0sjM0rAjZUBsAZnIe4AIiC0gOj39y+KILSAHeQ8QB7EFhCbvAQIitoAE5D1AKMQWEJce/ok7YgsITd4DBERsAUHJe4CwiC0gLnkPACxk2zbvETAn3pQZQD7EFoBkiC0AyRBbAJIhtgAk88F3fADkQtsCkMz/AWx78dIaTbWCAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=402x202 at 0x7F2BB85F5E48>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fixed seed for reproducibility (presumably dummy_triangles draws from NumPy's\n",
    "# global RNG -- confirm in demo_utils).\n",
    "np.random.seed(18)\n",
    "# One label image per batch element.\n",
    "labels_ = [dummy_triangles(H, classes) for b in range(B)]\n",
    "# np.stack needs a real sequence: on Python 3 `map` returns an iterator, which\n",
    "# NumPy >= 1.16 deprecates and newer versions reject, so build a list explicitly.\n",
    "labels = np.stack([np.array(img) for img in labels_]).astype(np.uint8)\n",
    "# Display the label images side by side.\n",
    "pil_grid(labels_, 5, margin=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.847418Z",
     "start_time": "2019-02-26T17:10:15.804496Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "np.random.seed(57)\n",
    "feats = np.zeros(labels.shape)\n",
    "# (label value, mean, standard deviation) of the Gaussian features drawn for each\n",
    "# label; the order fixes the sequence of draws from the seeded global RNG.\n",
    "for label_value, mean, std in [(0, -1, 2.5), (1, 1, 2.5), (255, 0, 5)]:\n",
    "    mask = labels == label_value\n",
    "    feats[mask] = np.random.normal(mean, std, size=feats.shape)[mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.907285Z",
     "start_time": "2019-02-26T17:10:15.848545Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.945032Z",
     "start_time": "2019-02-26T17:10:15.909061Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "labels_tf = tf.placeholder(np.int64, shape=(None, H, W), name='labels')\n",
    "feats_tf = tf.placeholder(np.float32, shape=(None, H, W), name='features')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:15.973877Z",
     "start_time": "2019-02-26T17:10:15.946100Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "feed_dict = {labels_tf: labels, \n",
    "             feats_tf: feats}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.013115Z",
     "start_time": "2019-02-26T17:10:15.974973Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model:\n",
    "    '''\n",
    "    Simple linear model: adds a single learnable scalar bias to the input features.\n",
    "\n",
    "    `scores`, `probas` and `predict` are wrapped with `define_scope` (from\n",
    "    demo_utils_tf) and are read as attributes rather than called.\n",
    "    '''\n",
    "    \n",
    "    def __init__(self, feats):\n",
    "        '''\n",
    "        feats: tensor of input features; the scores are feats + bias.\n",
    "        '''\n",
    "        self.feats = feats\n",
    "        # Touch these attributes so their ops (and the bias variable) exist before the\n",
    "        # variable initializer is built; presumably define_scope constructs each one\n",
    "        # lazily on first access -- confirm against demo_utils_tf.\n",
    "        self.scores\n",
    "        self.predict\n",
    "\n",
    "    @define_scope\n",
    "    def scores(self):\n",
    "        '''Raw scores: the input features shifted by the scalar bias.'''\n",
    "        x = self.feats\n",
    "        # dtype must be passed by keyword: the second positional parameter of\n",
    "        # tf.Variable is `trainable`, not `dtype`.\n",
    "        bias = tf.Variable(0., dtype=x.dtype, name=\"bias\")\n",
    "        return tf.add(x, bias, name=\"scores\")\n",
    "    \n",
    "    @define_scope\n",
    "    def probas(self):\n",
    "        '''Sigmoid of the scores.'''\n",
    "        return tf.nn.sigmoid(self.scores)\n",
    "    \n",
    "    @define_scope\n",
    "    def predict(self):\n",
    "        '''Boolean tensor that is True where the score is >= 0.'''\n",
    "        return tf.greater_equal(self.scores, 0, name=\"predict\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.048152Z",
     "start_time": "2019-02-26T17:10:16.014724Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "m = Model(feats_tf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.078972Z",
     "start_time": "2019-02-26T17:10:16.049354Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lovász-Hinge value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.244800Z",
     "start_time": "2019-02-26T17:10:16.081932Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss = L.lovasz_hinge(m.scores, labels_tf, ignore=IGNORE, per_image=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.307823Z",
     "start_time": "2019-02-26T17:10:16.246970Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.4664207"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss.eval(feed_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Gradient w.r.t. bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.592486Z",
     "start_time": "2019-02-26T17:10:16.309527Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before flag parsing goes to stderr.\n",
      "W0226 18:10:16.500215 139827989141312 deprecation.py:323] From /esat/visicsdata/mberman/miniconda2/envs/lamlearn/lib/python3.6/site-packages/tensorflow/python/ops/array_grad.py:425: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.cast instead.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.57314444"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.gradients(loss, tf.trainable_variables()[0])[0].eval(feed_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyTorch reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
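    "The PyTorch implementation should give the same loss value and bias gradient as the TensorFlow one (up to floating-point precision)."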
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.703262Z",
     "start_time": "2019-02-26T17:10:16.594132Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from torch import nn\n",
    "\n",
    "class ModelPytorch(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ModelPytorch, self).__init__()\n",
    "        self.bias = nn.Parameter(torch.Tensor([0]))\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return x + self.bias\n",
    "\n",
    "m_torch = ModelPytorch()\n",
    "out = m_torch(Variable(torch.from_numpy(feats.astype(np.float32))))\n",
    "loss_torch = L_pytorch.lovasz_hinge(out, torch.from_numpy(labels), ignore=IGNORE, per_image=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.747415Z",
     "start_time": "2019-02-26T17:10:16.704610Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.4664201736450195"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_torch.data.numpy().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Gradient w.r.t. bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.800181Z",
     "start_time": "2019-02-26T17:10:16.748739Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss_torch.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.834715Z",
     "start_time": "2019-02-26T17:10:16.801857Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5731444358825684"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m_torch.bias.grad.data.numpy().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lovász-Sigmoid value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:16.981434Z",
     "start_time": "2019-02-26T17:10:16.836137Z"
    }
   },
   "outputs": [],
   "source": [
    "loss = L.lovasz_softmax(m.probas, labels_tf, classes=[1], ignore=IGNORE, per_image=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:17.040635Z",
     "start_time": "2019-02-26T17:10:16.982684Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.80367076"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss.eval(feed_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:41.047477Z",
     "start_time": "2019-02-26T17:10:40.733537Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.032584023"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.gradients(loss, tf.trainable_variables()[0])[0].eval(feed_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T16:58:05.605057Z",
     "start_time": "2019-02-26T16:58:05.421586Z"
    }
   },
   "source": [
    "### PyTorch reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:17.073915Z",
     "start_time": "2019-02-26T17:10:17.042222Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sigmoid_out = torch.sigmoid(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:17.133917Z",
     "start_time": "2019-02-26T17:10:17.075353Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8037, grad_fn=<DivBackward0>)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss_torch = L_pytorch.lovasz_softmax(sigmoid_out, torch.from_numpy(labels), classes=[1], ignore=IGNORE, per_image=True)\n",
    "loss_torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:17.172649Z",
     "start_time": "2019-02-26T17:10:17.135408Z"
    }
   },
   "outputs": [],
   "source": [
    "m_torch.bias.grad.data.fill_(0)\n",
    "loss_torch.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-02-26T17:10:17.207204Z",
     "start_time": "2019-02-26T17:10:17.173987Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.03258401155471802"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m_torch.bias.grad.data.numpy().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [conda env:lamlearn]",
   "language": "python",
   "name": "conda-env-lamlearn-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
